CNN for MNIST

Still WIP


In [32]:
import tensorflow as tf
import numpy as np

In [33]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [34]:
# each layer consists of a f(xW + b) where f is a ReLU activation function
# For ReLU units, a study published in 2015 by He et al. demonstrates that the variance
# of weights in a network should be 2/n_in , where n_in is the number inputs coming into in the neuron.
def layer(input, weight_shape, bias_shape):
    weight_stddev = (2.0 / weight_shape[0])**0.5 # variance=2/n_in as explained above
    w_init = tf.random_normal_initializer(stddev=weight_stddev)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable("W", weight_shape, initializer=w_init)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    return tf.nn.relu(tf.matmul(input, W) + b)

In [37]:
# inference is the way we pass our input x through the nn of 2 hidden layers and 1 output layer
def inference(x, keep_prob):
    x = tf.reshape(x, shape=[-1, 28, 28, 1])
    with tf.variable_scope("conv_1"):
        conv_1 = conv2d(x, [5, 5, 1, 32], [32])
        pool_1 = max_pool(conv_1)
    with tf.variable_scope("conv_2"):
        conv_2 = conv2d(pool_1, [5, 5, 32, 64], [64])
        pool_2 = max_pool(conv_2)
    with tf.variable_scope("fc"):
        pool_2_flat = tf.reshape(pool_2, [-1, 7 * 7 * 64])
        fc_1 = layer(pool_2_flat, [7 * 7 * 64, 1024], [1024])
        # apply dropout
        fc_1_drop = tf.nn.dropout(fc_1, keep_prob)
    with tf.variable_scope("output"):
        output = layer(fc_1_drop, [1024, 10], [10])
    return output

In [38]:
# softmax moved form inference to loss function for performance improvement
def loss(output, y):
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y)
    loss = tf.reduce_mean(xentropy)
    return loss

In [40]:
def training(cost, global_step):
    tf.summary.tensor_summary("cost", cost)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(cost, global_step=global_step)
    return train_op

In [41]:
def evaluate(output, y):
    # compare indices of predicted class, if equal (correct classification) set 1 otherwise 0
    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

CNN functions


In [42]:
def conv2d(input, weight_shape, bias_shape):
    in_ = weight_shape[0] * weight_shape[1] * weight_shape[2]
    weight_init = tf.random_normal_initializer(stddev=(2.0 / in_)**0.5) # see above for explanation
    W = tf.get_variable("W", weight_shape, initializer=weight_init)
    bias_init = tf.constant_initializer(value=0)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    conv_out = tf.nn.conv2d(input, W, strides=[1, 1, 1, 1], padding='SAME') # keep the width and height constant between input and output tensors
    return tf.nn.relu(tf.nn.bias_add(conv_out, b))

In [43]:
def max_pool(input, k=2):
    return tf.nn.max_pool(
        input, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')

In [44]:
# Parameters
learning_rate = 0.01
training_epochs = 100
batch_size = 100
display_step = 1
keep_prop = 0.5

In [46]:
from tqdm import tqdm

# program flow
with tf.Graph().as_default():
    # mnist data image of shape 28*28=784
    x = tf.placeholder("float", [None, 784])
    # 0-9 digits recognition => 10 classes
    y = tf.placeholder("float", [None, 10])
    output = inference(x, keep_prop)
    cost = loss(output, y)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = training(cost, global_step)
    eval_op = evaluate(output, y)

    # tf.merge_all_summaries in order to collect all summary statistics
    # use a tf.train.SummaryWriter to write the log to disk.
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver()
    sess = tf.Session()
    # write to tensorboard graph api
    summary_writer = tf.summary.FileWriter(
        "logistic_logs/", graph=sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # training cycle
    for epoch in tqdm(range(training_epochs)):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)

        # Loop over all batches
        for i in range(total_batch):
            mbatch_x, mbatch_y = mnist.train.next_batch(batch_size)

            # Fit training using batch data
            feed_dict = {x: mbatch_x, y: mbatch_y}
            sess.run(train_op, feed_dict=feed_dict)

            # Compute average loss
            minibatch_cost = sess.run(cost, feed_dict=feed_dict)
            avg_cost += minibatch_cost / total_batch
        # Display logs per epoch step
        if epoch % display_step == 0:
            val_feed_dict = {
                x: mnist.validation.images,
                y: mnist.validation.labels
            }
            accuracy = sess.run(eval_op, feed_dict=val_feed_dict)
            print("Validation Error in epoch %s: %.11f" % (epoch, 1 - accuracy))
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, sess.run(global_step))
            saver.save(
                sess,
                "logistic_logs/model-checkpoint",
                global_step=global_step)

    test_feed_dict = {x: mnist.test.images, y: mnist.test.labels}
    accuracy = sess.run(eval_op, feed_dict=test_feed_dict)
    print("Test Accuracy:", accuracy)


  0%|          | 0/100 [00:00<?, ?it/s]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-46-3de98936cdfb> in <module>()
     34             # Fit training using batch data
     35             feed_dict = {x: mbatch_x, y: mbatch_y}
---> 36             sess.run(train_op, feed_dict=feed_dict)
     37 
     38             # Compute average loss

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    893     try:
    894       result = self._run(None, fetches, feed_dict, options_ptr,
--> 895                          run_metadata_ptr)
    896       if run_metadata:
    897         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1122     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1123       results = self._do_run(handle, final_targets, final_fetches,
-> 1124                              feed_dict_tensor, options, run_metadata)
   1125     else:
   1126       results = []

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1319     if handle is None:
   1320       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1321                            options, run_metadata)
   1322     else:
   1323       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1325   def _do_call(self, fn, *args):
   1326     try:
-> 1327       return fn(*args)
   1328     except errors.OpError as e:
   1329       message = compat.as_text(e.message)

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1304           return tf_session.TF_Run(session, options,
   1305                                    feed_dict, fetch_list, target_list,
-> 1306                                    status, run_metadata)
   1307 
   1308     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: